use crate::{StageCheckpoint, StageId};
use alloy_primitives::{BlockHash, BlockNumber};
use futures_util::{Stream, StreamExt};
use reqwest::{Client, Url};
use reth_config::config::EtlConfig;
use reth_db_api::{table::Value, transaction::DbTxMut};
use reth_era::era1_file::Era1Reader;
use reth_era_downloader::{read_dir, EraClient, EraMeta, EraStream, EraStreamConfig};
use reth_era_utils as era;
use reth_etl::Collector;
use reth_primitives_traits::{FullBlockBody, FullBlockHeader, NodePrimitives};
use reth_provider::{
    BlockReader, BlockWriter, DBProvider, HeaderProvider, StageCheckpointWriter,
    StaticFileProviderFactory, StaticFileWriter,
};
use reth_stages_api::{ExecInput, ExecOutput, Stage, StageError, UnwindInput, UnwindOutput};
use reth_static_file_types::StaticFileSegment;
use reth_storage_errors::ProviderError;
use std::{
    fmt::{Debug, Formatter},
    iter,
    path::Path,
    task::{ready, Context, Poll},
};

/// A boxed iterator yielding fallible block `Header` and `Body` pairs.
type Item<Header, Body> =
    Box<dyn Iterator<Item = eyre::Result<(Header, Body)>> + Send + Sync + Unpin>;
/// A boxed stream whose elements are fallible [`Item`]s.
type ThreadSafeEraStream<Header, Body> =
    Box<dyn Stream<Item = eyre::Result<Item<Header, Body>>> + Send + Sync + Unpin>;

/// The [ERA1](https://github.com/eth-clients/e2store-format-specs/blob/main/formats/era1.md)
/// pre-merge history stage.
///
/// Imports block headers and bodies from genesis up to the last pre-merge block. Receipts are
/// generated by execution. Execution is not done in this stage.
pub struct EraStage<Header, Body, StreamFactory> {
    /// The `source` creates `stream`.
    ///
    /// It is cloned and consumed the first time a stream is needed. When it is `None`, no stream
    /// is ever created.
    source: Option<StreamFactory>,
    /// A map of block hash to block height collected when processing headers and inserted into
    /// database afterward.
    hash_collector: Collector<BlockHash, BlockNumber>,
    /// Last extracted iterator of block `Header` and `Body` pairs.
    ///
    /// Filled while polling for readiness and taken out again by `execute`.
    item: Option<Item<Header, Body>>,
    /// A stream of [`Item`]s, i.e. iterators over block `Header` and `Body` pairs.
    ///
    /// Stays `None` until it is first created from `source`.
    stream: Option<ThreadSafeEraStream<Header, Body>>,
}

/// A one-shot factory for the [`ThreadSafeEraStream`] consumed by [`EraStage`].
trait EraStreamFactory<Header, Body> {
    /// Consumes the factory and builds the stream for the given execution `input`.
    ///
    /// # Errors
    ///
    /// Returns a [`StageError`] when the stream cannot be set up.
    fn create(self, input: ExecInput) -> Result<ThreadSafeEraStream<Header, Body>, StageError>;
}

impl<Header, Body> EraStreamFactory<Header, Body> for EraImportSource
where
    Header: FullBlockHeader + Value,
    Body: FullBlockBody<OmmerHeader = Header>,
{
    /// Builds the stream of ERA files, starting from `input.next_block()`.
    ///
    /// * [`EraImportSource::Path`] reads the files from the local directory.
    /// * [`EraImportSource::Url`] makes sure the download folder exists and then streams the
    ///   files through an [`EraClient`].
    ///
    /// # Errors
    ///
    /// Returns [`StageError::Fatal`] if the local directory cannot be read or the download
    /// folder cannot be created.
    fn create(self, input: ExecInput) -> Result<ThreadSafeEraStream<Header, Body>, StageError> {
        match self {
            Self::Path(path) => Self::convert(
                read_dir(path, input.next_block()).map_err(|e| StageError::Fatal(e.into()))?,
            ),
            Self::Url(url, folder) => {
                // Report a missing or uncreatable download folder right away with its real
                // cause, instead of ignoring it and letting the download stream fail later.
                reth_fs_util::create_dir_all(&folder)
                    .map_err(|e| StageError::Fatal(e.into()))?;
                let client = EraClient::new(Client::new(), url, folder);

                Self::convert(EraStream::new(
                    client,
                    EraStreamConfig::default().start_from(input.next_block()),
                ))
            }
        }
    }
}

impl EraImportSource {
    /// Turns a stream of ERA file metadata into a stream of block iterators.
    ///
    /// For each successfully yielded `meta`, the file at `meta.path()` is opened, wrapped in an
    /// [`Era1Reader`] and every entry is passed through [`era::decode`]. After the last entry,
    /// `meta.mark_as_processed()` is invoked; if that call fails, its error is yielded as one
    /// final item of the iterator. Errors of the input stream are forwarded unchanged.
    fn convert<Header, Body>(
        stream: impl Stream<Item = eyre::Result<impl EraMeta + Send + Sync + 'static + Unpin>>
            + Send
            + Sync
            + 'static
            + Unpin,
    ) -> Result<ThreadSafeEraStream<Header, Body>, StageError>
    where
        Header: FullBlockHeader + Value,
        Body: FullBlockBody<OmmerHeader = Header>,
    {
        let items = stream.map(|meta| -> eyre::Result<Item<Header, Body>> {
            let meta = meta?;
            let file = reth_fs_util::open(meta.path())?;
            let decoded = Era1Reader::new(file).iter().map(era::decode);

            // Runs only once all entries were consumed; contributes an item solely on failure.
            let finalize = iter::once_with(move || meta.mark_as_processed().err().map(Err));
            let iter = decoded.chain(finalize.flatten());

            Ok(Box::new(iter) as Item<Header, Body>)
        });

        Ok(Box::new(Box::pin(items)))
    }
}

// Only `F` needs to be `Debug`: `item` is reported as a presence flag and `stream` as a fixed
// placeholder, so neither `Header` nor `Body` is ever formatted.
impl<Header, Body, F: Debug> Debug for EraStage<Header, Body, F> {
    /// Formats the stage without touching the type-erased iterator and stream.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("EraStage")
            .field("source", &self.source)
            .field("hash_collector", &self.hash_collector)
            .field("item", &self.item.is_some())
            .field("stream", &"dyn Stream")
            .finish()
    }
}

impl<Header, Body, F> EraStage<Header, Body, F> {
    /// Creates a new [`EraStage`].
    ///
    /// The `source`, if any, is kept until a stream is first needed. The hash collector is set
    /// up from the file size and directory in `etl_config`. No item or stream exists initially.
    pub fn new(source: Option<F>, etl_config: EtlConfig) -> Self {
        let hash_collector = Collector::new(etl_config.file_size, etl_config.dir);

        Self { source, hash_collector, item: None, stream: None }
    }
}

impl<Provider, N, F> Stage<Provider> for EraStage<N::BlockHeader, N::BlockBody, F>
where
    Provider: DBProvider<Tx: DbTxMut>
        + StaticFileProviderFactory<Primitives = N>
        + BlockWriter<Block = N::Block>
        + BlockReader<Block = N::Block>
        + StageCheckpointWriter,
    F: EraStreamFactory<N::BlockHeader, N::BlockBody> + Send + Sync + Clone,
    N: NodePrimitives<BlockHeader: Value>,
{
    fn id(&self) -> StageId {
        StageId::Era
    }

    fn poll_execute_ready(
        &mut self,
        cx: &mut Context<'_>,
        input: ExecInput,
    ) -> Poll<Result<(), StageError>> {
        if input.target_reached() || self.item.is_some() {
            return Poll::Ready(Ok(()));
        }

        if self.stream.is_none() {
            if let Some(source) = self.source.clone() {
                self.stream.replace(source.create(input)?);
            }
        }
        if let Some(stream) = &mut self.stream {
            if let Some(next) = ready!(stream.poll_next_unpin(cx))
                .transpose()
                .map_err(|e| StageError::Fatal(e.into()))?
            {
                self.item.replace(next);
            }
        }

        Poll::Ready(Ok(()))
    }

    fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError> {
        let height = if let Some(era) = self.item.take() {
            let static_file_provider = provider.static_file_provider();

            // Consistency check of expected headers in static files vs DB is done on
            // provider::sync_gap when poll_execute_ready is polled.
            let last_header_number = static_file_provider
                .get_highest_static_file_block(StaticFileSegment::Headers)
                .unwrap_or_default();

            // Find the latest total difficulty
            let mut td = static_file_provider
                .header_td_by_number(last_header_number)?
                .ok_or(ProviderError::TotalDifficultyNotFound(last_header_number))?;

            // Although headers were downloaded in reverse order, the collector iterates it in
            // ascending order
            let mut writer = static_file_provider.latest_writer(StaticFileSegment::Headers)?;

            let height = era::process_iter(
                era,
                &mut writer,
                provider,
                &mut self.hash_collector,
                &mut td,
                last_header_number..=input.target(),
            )
            .map_err(|e| StageError::Fatal(e.into()))?;

            if !self.hash_collector.is_empty() {
                era::build_index(provider, &mut self.hash_collector)
                    .map_err(|e| StageError::Recoverable(e.into()))?;
                self.hash_collector.clear();
            }

            era::save_stage_checkpoints(
                &provider,
                input.checkpoint().block_number,
                height,
                height,
                input.target(),
            )?;

            height
        } else {
            input.target()
        };

        Ok(ExecOutput { checkpoint: StageCheckpoint::new(height), done: height == input.target() })
    }

    fn unwind(
        &mut self,
        _provider: &Provider,
        input: UnwindInput,
    ) -> Result<UnwindOutput, StageError> {
        Ok(UnwindOutput { checkpoint: input.checkpoint.with_block_number(input.unwind_to) })
    }
}

/// Describes where to get the era files from.
#[derive(Debug, Clone)]
pub enum EraImportSource {
    /// Remote HTTP accessible host.
    ///
    /// Carries the host [`Url`] and the local folder that is handed to the download client.
    Url(Url, Box<Path>),
    /// Local directory.
    Path(Box<Path>),
}

impl EraImportSource {
    /// Picks an import source from the given arguments, if any of them yields one.
    ///
    /// Callers are expected to pass at most one of `path` and `url`; this is not enforced here,
    /// and `path` wins when both are present.
    ///
    /// # Arguments
    /// * `path` selects a local directory as the import source. It and its contents must be
    ///   readable.
    /// * `url` selects an HTTP host from which files are listed and downloaded.
    /// * `default` is consulted for a fallback [`Url`] only when neither `path` nor `url` is
    ///   given.
    /// * `folder` is called only when a [`Url`] is used and provides the download directory for
    ///   storing files temporarily. It and its contents must be readable and writable.
    ///
    /// Returns `None` when no path is given and no [`Url`] can be determined.
    pub fn maybe_new(
        path: Option<Box<Path>>,
        url: Option<Url>,
        default: impl FnOnce() -> Option<Url>,
        folder: impl FnOnce() -> Box<Path>,
    ) -> Option<Self> {
        match path {
            Some(dir) => Some(Self::Path(dir)),
            None => {
                let url = url.or_else(default)?;
                Some(Self::Url(url, folder()))
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{
        stage_test_suite, ExecuteStageTestRunner, StageTestRunner, UnwindStageTestRunner,
    };
    use alloy_primitives::B256;
    use assert_matches::assert_matches;
    use reth_db_api::tables;
    use reth_provider::BlockHashReader;
    use reth_testing_utils::generators::{self, random_header};
    use test_runner::EraTestRunner;

    /// Runs the stage twice towards a target that lies above the seeded blocks.
    ///
    /// The first run is expected to stop at the seeded height without being done; the second
    /// run, after the stub responses were removed, is expected to report the target as reached.
    #[tokio::test]
    async fn test_era_range_ends_below_target() {
        let era_cap = 2;
        let target = 20000;

        let mut runner = EraTestRunner::default();

        // Seed blocks `0..=era_cap` only, so the available data ends below `target`.
        let input = ExecInput { target: Some(era_cap), checkpoint: None };
        runner.seed_execution(input).unwrap();

        // First run: ask for `target` although only `era_cap` blocks are available.
        let input = ExecInput { target: Some(target), checkpoint: None };
        let output = runner.execute(input).await.unwrap();

        runner.commit();

        assert_matches!(
            output,
            Ok(ExecOutput {
                checkpoint: StageCheckpoint { block_number, stage_checkpoint: None },
                done: false
            }) if block_number == era_cap
        );

        let output = output.unwrap();
        let validation_output = runner.validate_execution(input, Some(output.clone()));

        assert_matches!(validation_output, Ok(()));

        // Drop the stub responses; `EraTestRunner::stage` then builds a stage without a source.
        runner.take_responses();

        // Second run: continue from the previous checkpoint towards the same target.
        let input = ExecInput { target: Some(target), checkpoint: Some(output.checkpoint) };
        let output = runner.execute(input).await.unwrap();

        runner.commit();

        assert_matches!(
            output,
            Ok(ExecOutput {
                checkpoint: StageCheckpoint { block_number, stage_checkpoint: None },
                done: true
            }) if block_number == target
        );

        let validation_output = runner.validate_execution(input, output.ok());

        assert_matches!(validation_output, Ok(()));
    }

    mod test_runner {
        use super::*;
        use crate::test_utils::{TestRunnerError, TestStageDB};
        use alloy_consensus::{BlockBody, Header};
        use futures_util::stream;
        use reth_db_api::{
            cursor::DbCursorRO,
            models::{StoredBlockBodyIndices, StoredBlockOmmers},
            transaction::DbTx,
        };
        use reth_ethereum_primitives::TransactionSigned;
        use reth_primitives_traits::{SealedBlock, SealedHeader};
        use reth_provider::{BlockNumReader, TransactionsProvider};
        use reth_testing_utils::generators::{
            random_block_range, random_signed_tx, BlockRangeParams,
        };
        use tokio::sync::watch;

        /// Test harness that builds [`EraStage`]s backed by a [`TestStageDB`].
        pub(crate) struct EraTestRunner {
            /// Watch channel used by `send_tip`; both halves are stored here.
            channel: (watch::Sender<B256>, watch::Receiver<B256>),
            /// The test database.
            db: TestStageDB,
            /// Header and body pairs given to the stage as its source. When `None`, the stage is
            /// built without a source.
            responses: Option<Vec<(Header, BlockBody<TransactionSigned>)>>,
        }

        impl Default for EraTestRunner {
            /// Creates a runner with a zero-hash tip channel, a fresh test database and no
            /// stub responses.
            fn default() -> Self {
                let channel = watch::channel(B256::ZERO);
                let db = TestStageDB::default();

                Self { channel, db, responses: None }
            }
        }

        impl StageTestRunner for EraTestRunner {
            type S = EraStage<Header, BlockBody<TransactionSigned>, StubResponses>;

            /// Gives access to the test database.
            fn db(&self) -> &TestStageDB {
                &self.db
            }

            /// Builds a stage whose source is a copy of the current stub responses, if any.
            fn stage(&self) -> Self::S {
                let source = self.responses.clone().map(StubResponses);

                EraStage::new(source, EtlConfig::default())
            }
        }

        impl ExecuteStageTestRunner for EraTestRunner {
            type Seed = Vec<SealedBlock<reth_ethereum_primitives::Block>>;

            /// Generates random blocks `0..=input.target()` and prepares the runner with them.
            ///
            /// All headers are inserted together with their total difficulty. If a block exists
            /// at the checkpoint height, its body indices, transactions and ommers are written
            /// too. Finally, every generated block is stored as a stub response for the stage.
            fn seed_execution(&mut self, input: ExecInput) -> Result<Self::Seed, TestRunnerError> {
                let start = input.checkpoint().block_number;
                let end = input.target();

                let static_file_provider = self.db.factory.static_file_provider();

                let mut rng = generators::rng();

                // Static files do not support gaps in headers, so we need to generate 0 to end
                let blocks = random_block_range(
                    &mut rng,
                    0..=end,
                    BlockRangeParams {
                        parent: Some(B256::ZERO),
                        tx_count: 0..2,
                        ..Default::default()
                    },
                );
                self.db.insert_headers_with_td(blocks.iter().map(|block| block.sealed_header()))?;
                if let Some(progress) = blocks.get(start as usize) {
                    // Insert last progress data
                    {
                        let tx = self.db.factory.provider_rw()?.into_tx();
                        let mut static_file_producer = static_file_provider
                            .get_writer(start, StaticFileSegment::Transactions)?;

                        let body = StoredBlockBodyIndices {
                            first_tx_num: 0,
                            tx_count: progress.transaction_count() as u64,
                        };

                        static_file_producer.set_block_range(0..=progress.number);

                        // The stored transactions are freshly generated random ones.
                        body.tx_num_range().try_for_each(|tx_num| {
                            let transaction = random_signed_tx(&mut rng);
                            static_file_producer.append_transaction(tx_num, &transaction).map(drop)
                        })?;

                        if body.tx_count != 0 {
                            tx.put::<tables::TransactionBlocks>(
                                body.last_tx_num(),
                                progress.number,
                            )?;
                        }

                        tx.put::<tables::BlockBodyIndices>(progress.number, body)?;

                        if !progress.ommers_hash_is_empty() {
                            tx.put::<tables::BlockOmmers>(
                                progress.number,
                                StoredBlockOmmers { ommers: progress.body().ommers.clone() },
                            )?;
                        }

                        static_file_producer.commit()?;
                        tx.commit()?;
                    }
                }
                // Every generated block becomes a stub response.
                self.responses.replace(
                    blocks.iter().map(|v| (v.header().clone(), v.body().clone())).collect(),
                );
                Ok(blocks)
            }

            /// Validate stored headers and bodies
            ///
            /// If the output checkpoint advanced past the input checkpoint, the headers from
            /// the input checkpoint up to (excluding) the smaller of the output checkpoint and
            /// the number of stub responses are checked, followed by the stored bodies.
            /// Otherwise no header entry may exist above the input checkpoint.
            fn validate_execution(
                &self,
                input: ExecInput,
                output: Option<ExecOutput>,
            ) -> Result<(), TestRunnerError> {
                let initial_checkpoint = input.checkpoint().block_number;
                match output {
                    Some(output) if output.checkpoint.block_number > initial_checkpoint => {
                        let provider = self.db.factory.provider()?;
                        // Start from the total difficulty of the block before the checkpoint,
                        // or from the default value when none is stored.
                        let mut td = provider
                            .header_td_by_number(initial_checkpoint.saturating_sub(1))?
                            .unwrap_or_default();

                        for block_num in initial_checkpoint..
                            output
                                .checkpoint
                                .block_number
                                .min(self.responses.as_ref().map(|v| v.len()).unwrap_or_default()
                                    as BlockNumber)
                        {
                            // look up the header hash
                            let hash = provider.block_hash(block_num)?.expect("no header hash");

                            // validate the header number
                            assert_eq!(provider.block_number(hash)?, Some(block_num));

                            // validate the header
                            let header = provider.header_by_number(block_num)?;
                            assert!(header.is_some());
                            let header = SealedHeader::seal_slow(header.unwrap());
                            assert_eq!(header.hash(), hash);

                            // validate the header total difficulty
                            td += header.difficulty;
                            assert_eq!(provider.header_td_by_number(block_num)?, Some(td));
                        }

                        // The output checkpoint is passed as both the previous progress and
                        // the highest allowed block.
                        self.validate_db_blocks(
                            output.checkpoint.block_number,
                            output.checkpoint.block_number,
                        )?;
                    }
                    _ => self.check_no_header_entry_above(initial_checkpoint)?,
                };
                Ok(())
            }

            /// Sends the hash of the last seeded block as the tip.
            ///
            /// When `headers` is empty, a random header for block 0 is inserted and its hash is
            /// sent instead.
            async fn after_execution(&self, headers: Self::Seed) -> Result<(), TestRunnerError> {
                let tip = if headers.is_empty() {
                    let tip = random_header(&mut generators::rng(), 0, None);
                    self.db.insert_headers(iter::once(&tip))?;
                    tip.hash()
                } else {
                    headers.last().unwrap().hash()
                };
                self.send_tip(tip);
                Ok(())
            }
        }

        impl UnwindStageTestRunner for EraTestRunner {
            /// Accepts every unwind without inspecting the database.
            fn validate_unwind(&self, _input: UnwindInput) -> Result<(), TestRunnerError> {
                Ok(())
            }
        }

        impl EraTestRunner {
            /// Checks that none of the header related tables has an entry above `block`.
            pub(crate) fn check_no_header_entry_above(
                &self,
                block: BlockNumber,
            ) -> Result<(), TestRunnerError> {
                let db = &self.db;

                db.ensure_no_entry_above_by_value::<tables::HeaderNumbers, _>(block, |val| val)?;
                db.ensure_no_entry_above::<tables::CanonicalHeaders, _>(block, |key| key)?;
                db.ensure_no_entry_above::<tables::Headers, _>(block, |key| key)?;
                db.ensure_no_entry_above::<tables::HeaderTerminalDifficulties, _>(
                    block,
                    |num| num,
                )?;

                Ok(())
            }

            /// Publishes `tip` on the watch channel.
            pub(crate) fn send_tip(&self, tip: B256) {
                let (sender, _) = &self.channel;
                sender.send(tip).expect("failed to send tip");
            }

            /// Walks all stored block bodies and asserts that the body related data is
            /// consistent.
            ///
            /// Entries above `prev_progress` must be sequential, and no entry may lie above
            /// `highest_block`. For every entry the ommers, the transaction-to-block mapping
            /// and the stored transactions are checked.
            pub(crate) fn validate_db_blocks(
                &self,
                prev_progress: BlockNumber,
                highest_block: BlockNumber,
            ) -> Result<(), TestRunnerError> {
                let static_file_provider = self.db.factory.static_file_provider();

                self.db.query(|tx| {
                    // Cursors over the body related tables
                    let mut bodies_cursor = tx.cursor_read::<tables::BlockBodyIndices>()?;
                    let mut ommers_cursor = tx.cursor_read::<tables::BlockOmmers>()?;
                    let mut tx_block_cursor = tx.cursor_read::<tables::TransactionBlocks>()?;

                    // Without any stored body there is nothing to validate
                    let Some((first_body_key, _)) = bodies_cursor.first()? else {
                        return Ok(());
                    };

                    let mut prev_number: Option<BlockNumber> = None;

                    for entry in bodies_cursor.walk(Some(first_body_key))? {
                        let (number, body) = entry?;

                        // Sequentiality is only required after the previous progress, since
                        // the data before it is mocked and can contain gaps
                        match prev_number {
                            Some(prev_key) if number > prev_progress => {
                                assert_eq!(prev_key + 1, number, "Body entries must be sequential");
                            }
                            _ => {}
                        }

                        // No entry may exceed the highest allowed block
                        assert!(
                            number <= highest_block,
                            "We wrote a block body outside of our synced range. Found block with number {number}, highest block according to stage is {highest_block}",
                        );

                        let header =
                            static_file_provider.header_by_number(number)?.expect("to be present");

                        // An ommers entry must exist exactly when the header declares ommers
                        let has_ommers_entry = ommers_cursor.seek_exact(number)?.is_some();
                        if header.ommers_hash_is_empty() {
                            assert!(!has_ommers_entry, "Unexpected ommers entry");
                        } else {
                            assert!(has_ommers_entry, "Missing ommers entry");
                        }

                        // The last transaction maps back to this block only if the block
                        // has transactions
                        let owning_block = tx_block_cursor
                            .seek_exact(body.last_tx_num())?
                            .map(|(_, block)| block);
                        if body.tx_count == 0 {
                            assert_ne!(owning_block, Some(number));
                        } else {
                            assert_eq!(owning_block, Some(number));
                        }

                        for tx_id in body.tx_num_range() {
                            assert!(
                                static_file_provider.transaction_by_id(tx_id)?.is_some(),
                                "Transaction is missing."
                            );
                        }

                        prev_number = Some(number);
                    }

                    Ok(())
                })?;

                Ok(())
            }

            /// Discards the stub responses.
            pub(crate) fn take_responses(&mut self) {
                self.responses = None;
            }

            /// Commits the static file provider of the test database, panicking on failure.
            pub(crate) fn commit(&self) {
                let static_file_provider = self.db.factory.static_file_provider();
                static_file_provider.commit().unwrap();
            }
        }

        /// Stream factory stub holding a fixed list of header and body pairs.
        #[derive(Clone)]
        pub(crate) struct StubResponses(Vec<(Header, BlockBody<TransactionSigned>)>);

        impl EraStreamFactory<Header, BlockBody<TransactionSigned>> for StubResponses {
            fn create(
                self,
                _input: ExecInput,
            ) -> Result<ThreadSafeEraStream<Header, BlockBody<TransactionSigned>>, StageError>
            {
                let stream = stream::iter(vec![self.0]);

                Ok(Box::new(Box::pin(stream.map(|meta| {
                    Ok(Box::new(meta.into_iter().map(Ok))
                        as Item<Header, BlockBody<TransactionSigned>>)
                }))))
            }
        }
    }

    stage_test_suite!(EraTestRunner, era);
}
